In [2]:
import pandas as pd
import numpy as np
import altair as alt

import theme
from natsort import natsorted, natsort_keygen

# Allow Altair to embed datasets larger than its default 5,000-row limit.
alt.data_transformers.disable_max_rows()
Out[2]:
DataTransformerRegistry.enable('default')
In [3]:
# Per-mutation cell-entry effects merged across the H3/H5/H7 backgrounds,
# keyed by structure-aligned site ('struct_site').
mutation_effects = pd.read_csv('../results/combined_effects/combined_mutation_effects.csv')
mutation_effects.head()
Out[3]:
mutant struct_site 4o5n_aa 6ii9_aa 4kwm_aa 4o5n_aa_pdb_site 4o5n_aa_RSA 4o5n_aa_SS 4o5n_aa_chain 4kwm_aa_pdb_site ... h5_site h5_wt_aa h7_site h7_wt_aa h3_effect h3_effect_std h5_effect h5_effect_std h7_effect h7_effect_std
0 A 9 P - P 9.0 1.084277 - A -1.0 ... 9 K NaN NaN 0.0151 0.7225 0.0558 0.29180 NaN NaN
1 C 9 P - P 9.0 1.084277 - A -1.0 ... 9 K NaN NaN -0.4080 0.3850 -0.4245 0.02737 NaN NaN
2 D 9 P - P 9.0 1.084277 - A -1.0 ... 9 K NaN NaN 0.2361 0.2740 0.2039 0.07884 NaN NaN
3 E 9 P - P 9.0 1.084277 - A -1.0 ... 9 K NaN NaN -0.2463 0.8478 0.1713 0.10210 NaN NaN
4 F 9 P - P 9.0 1.084277 - A -1.0 ... 9 K NaN NaN 0.2061 0.3214 -0.8397 1.34800 NaN NaN

5 rows × 32 columns

In [4]:
# Site-level summaries: mean effect per site in each background plus
# structural annotations (RSA, secondary structure, pairwise RMSDs).
site_effects = pd.read_csv('../results/combined_effects/combined_site_effects.csv')
site_effects.head()
Out[4]:
struct_site 4o5n_aa 6ii9_aa 4kwm_aa 4o5n_aa_pdb_site 4o5n_aa_RSA 4o5n_aa_SS 4o5n_aa_chain 4kwm_aa_pdb_site 4kwm_aa_RSA ... rmsd_h5h7 h3_site h3_wt_aa h5_site h5_wt_aa h7_site h7_wt_aa avg_h3_effect avg_h5_effect avg_h7_effect
0 9 P - P 9.0 1.084277 - A -1.0 1.140252 ... NaN 9.0 S 9 K NaN NaN -0.050776 -1.062932 NaN
1 10 G - G 10.0 0.150962 - A 0.0 0.175962 ... NaN 10.0 T 10 S NaN NaN -0.697911 -3.224739 NaN
2 11 A D D 11.0 0.050388 E A 1.0 0.097927 ... 2.886615 11.0 A 11 D 11 D -3.138280 -3.921267 -2.963026
3 12 T K Q 12.0 0.268605 E A 2.0 0.216889 ... 3.384350 12.0 T 12 Q 12 K -1.036219 -0.467449 -1.727747
4 13 L I I 13.0 0.000000 E A 3.0 0.000000 ... 2.549524 13.0 L 13 I 13 I -3.941050 -3.885729 -3.840728

5 rows × 28 columns

In [5]:
# Read in protein sequence identities
seq_identity = pd.read_csv('../results/sequence_identity/ha_sequence_identity.csv')
seq_identity.head()
Out[5]:
ha_x ha_y matches alignable_residues percent_identity
0 H3 H5 192.0 479.0 40.083507
1 H3 H7 229.0 483.0 47.412008
2 H5 H7 202.0 473.0 42.706131
In [6]:
def _entry_scatter(x_ha, x_title, y_ha, y_title, chart_title):
    """One pairwise scatter of per-mutation entry effects.

    Parameters: lowercase subtype keys (`x_ha`, `y_ha`) used to build the
    effect/wildtype column names, multi-line axis titles, and a panel title.
    Returns the configured alt.Chart. (Replaces three copy-pasted chart
    definitions that differed only in these values.)
    """
    return alt.Chart(mutation_effects).mark_circle(
        size=25, opacity=0.3, color='#767676'
    ).encode(
        x=alt.X(f'{x_ha}_effect', title=x_title),
        y=alt.Y(f'{y_ha}_effect', title=y_title),
        tooltip=[
            'struct_site', 'mutant',
            f'{x_ha}_wt_aa', f'{y_ha}_wt_aa',
            f'{x_ha}_effect', f'{y_ha}_effect',
        ]
    ).properties(
        width=200,
        height=200,
        title=chart_title
    )

h3_h7_scatter = _entry_scatter(
    'h3', ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    'h7', ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H3 vs. H7',
)
h3_h5_scatter = _entry_scatter(
    'h3', ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    'h5', ['Effect on 293T entry', 'in H5 background'],
    'H3 vs. H5',
)
h5_h7_scatter = _entry_scatter(
    'h5', ['Effect on 293T entry', 'in H5 background'],
    'h7', ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H5 vs. H7',
)

h3_h7_scatter | h3_h5_scatter | h5_h7_scatter
Out[6]:
In [7]:
def scatter_and_density_plot(df, ha_x, ha_y, colors):
    """Scatter of per-site mean effects in two backgrounds with marginal densities.

    Parameters
    ----------
    df : pd.DataFrame
        Site-level frame with `avg_{ha}_effect` and `{ha}_wt_aa` columns.
    ha_x, ha_y : str
        Lowercase subtype keys (e.g. 'h3', 'h5') used to build column names.
    colors : dict
        Maps the two conservation categories to hex colors; the dict's key
        order fixes the color/legend domain order.

    Returns
    -------
    alt.VConcatChart
        x-marginal density on top; below it the scatter (with y=x reference
        line and Pearson-r label) and the y-marginal density side by side.
    """
    # Pearson correlation between the two backgrounds' mean site effects
    r_value = df[f'avg_{ha_x}_effect'].corr(df[f'avg_{ha_y}_effect'])
    r_text = f"r = {r_value:.2f}"

    # Dashed y = x reference line over the plotted effect range
    identity_line = alt.Chart(pd.DataFrame({'x': [-5, 0.3], 'y': [-5, 0.3]})).mark_line(
        strokeDash=[6, 6],
        color='black'
    ).encode(
        x='x',
        y='y'
    )

    # Flag whether the wildtype amino acid is the same in both backgrounds
    df = df.assign(
        same_wildtype= lambda x: np.where(
            x[f'{ha_x}_wt_aa'] == x[f'{ha_y}_wt_aa'],
            'Amino acid conserved',
            'Amino acid changed'
        ),
    )

    scatter = alt.Chart(df).mark_circle(
        size=35, opacity=1, stroke='black', strokeWidth=0.5
    ).encode(
        x=alt.X(f'avg_{ha_x}_effect', title=['Mean effect on cell entry', f'in {ha_x.upper()} background']),
        y=alt.Y(f'avg_{ha_y}_effect', title=['Mean effect on cell entry', f'in {ha_y.upper()} background']),
        color=alt.Color(
            'same_wildtype:N', 
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
        tooltip=['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', f'avg_{ha_x}_effect', f'avg_{ha_y}_effect']
    ).properties(
        width=175,
        height=175,
    )

    # Correlation coefficient drawn at fixed pixel coordinates (top-left)
    r_label = alt.Chart(pd.DataFrame({'text': [r_text]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(5), 
        y=alt.value(5)
    )

    # Marginal density along x; counts=True scales areas by group size so the
    # two conservation categories are comparable
    x_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_x}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_x}_effect'].min(), df[f'avg_{ha_x}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1).encode(
        alt.X('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.Y('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N', 
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=175,
        height=50
    )

    # Marginal density along y (rotated via orient='horizontal')
    y_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_y}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_y}_effect'].min(), df[f'avg_{ha_y}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1, orient='horizontal').encode(
        alt.Y('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.X('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N', 
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=50,
        height=175
    )

    # Assemble: x density on top, then (scatter + line + label) | y density
    marginal_plot = alt.vconcat(
        x_density,
        alt.hconcat(
            (scatter + identity_line + r_label),
            y_density
        )
    )
    return marginal_plot

# Category -> color mapping shared by all three panels; dict order fixes the
# legend/color domain order inside scatter_and_density_plot.
colors = {
    'Amino acid changed' : '#5484AF',
    'Amino acid conserved' : '#E04948'
}
p1 = scatter_and_density_plot(site_effects, 'h3', 'h5', colors=colors)
p2 = scatter_and_density_plot(site_effects, 'h3', 'h7', colors=colors)
p3 = scatter_and_density_plot(site_effects, 'h5', 'h7', colors=colors)

p1 | p2 | p3
Out[7]:

Calculate Jensen-Shannon Divergence¶

In [8]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) between discrete distributions.

    Parameters
    ----------
    p, q : array-like
        Equal-length probability vectors; `q` must be > 0 wherever `p` > 0.

    Returns
    -------
    float
        sum_i p_i * log(p_i / q_i), using the standard convention
        0 * log(0) = 0. (The previous form returned NaN whenever p_i == 0;
        callers in this notebook pass strictly positive exp-normalized
        vectors, so results are unchanged for them.)
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0  # skip zero-probability entries so 0*log(0) contributes 0
    return np.sum(p[support] * np.log(p[support] / q[support]))

def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Compute JS divergence at each site and merge it back to the dataframe.

    Mutation effects are mapped to a per-site probability distribution by
    exponentiation followed by normalization; the Jensen-Shannon divergence
    is then computed between the two backgrounds' distributions at each site.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain `{ha_x}_effect`, `{ha_y}_effect`, and `site_col` columns.
    ha_x, ha_y : str
        Subtype keys used to build effect column names (e.g. 'h3', 'h5').
    site_col : str
        Column identifying sites.
    min_mutations : int
        Minimum number of mutations measured in both backgrounds for a site
        to get a JSD value; sites below this get NaN.

    Returns
    -------
    pd.DataFrame
        Copy of `df` with an added `JS_{ha_x}_vs_{ha_y}` column (the same
        site-level value repeated on every row of that site).
    """
    js_per_site = {}

    for site, group in df.groupby(site_col):
        valid = group.dropna(subset=[f'{ha_x}_effect', f'{ha_y}_effect'])
        js_div = np.nan

        if len(valid) >= min_mutations:
            # exp() turns additive effects into positive weights; normalizing
            # gives a probability distribution over mutations at this site.
            p = np.exp(valid[f'{ha_x}_effect'].values)
            q = np.exp(valid[f'{ha_y}_effect'].values)

            p /= p.sum()
            q /= q.sum()

            m = 0.5 * (p + q)
            js_div = 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))

        js_per_site[site] = js_div

    # Create a column with the JS divergence duplicated across each row at the same site
    df = df.copy()
    col_name = f"JS_{ha_x}_vs_{ha_y}"
    df[col_name] = df[site_col].map(js_per_site)

    return df

# Per-site JSD for each pairwise background comparison; sites with fewer than
# 10 jointly measured mutations get NaN.
js_df_h3_h7 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h7', min_mutations=10)
js_df_h3_h5 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h5', min_mutations=10)
js_df_h5_h7 = compute_js_divergence_per_site(mutation_effects, 'h5', 'h7', min_mutations=10)

Are epistatic shifts significant?¶

In [9]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02
):
    """
    Compute JS divergence with bootstrap null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"

    The null is generated by computing two separate null distributions:
    1. ha_x null: Sample ha_x twice with its measurement error, compute JSD
    2. ha_y null: Sample ha_y twice with its measurement error, compute JSD
    3. Take the mean of the two null distributions (balanced approach)

    This accounts for measurement noise from both experiments without assuming they have
    identical underlying effects. A significant result means the observed JSD is larger
    than what measurement noise alone could produce.

    Only sites with observed JSD > jsd_threshold are tested for significance.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations required at a site
    n_bootstrap : int
        Number of bootstrap iterations
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value (NaN if below threshold)
        - n_mutations: number of mutations at site
        - null_distribution: array of null JSD samples for plotting
          (None for sites below the JSD threshold)
        Sorted by struct_site using natural sorting.
        Sites with fewer than `min_mutations` valid mutations are omitted
        entirely from the result.
    """
    # Seed the global NumPy RNG so all bootstrap draws are reproducible
    np.random.seed(random_seed)

    def compute_jsd_vectorized(effects, std, n_bootstrap):
        """Vectorized computation of null JSD distribution.

        Draws two independent Gaussian resamples of `effects` (per-mutation
        std given by `std`) for every bootstrap iteration and returns the
        JSD between the two resampled distributions, shape (n_bootstrap,).
        """
        n_mutations = len(effects)
        
        # Generate all bootstrap samples at once: shape (n_bootstrap, n_mutations)
        effects_1 = np.random.normal(
            loc=effects[np.newaxis, :],  # broadcast to (1, n_mutations)
            scale=std[np.newaxis, :],     # broadcast to (1, n_mutations)
            size=(n_bootstrap, n_mutations)
        )
        effects_2 = np.random.normal(
            loc=effects[np.newaxis, :],
            scale=std[np.newaxis, :],
            size=(n_bootstrap, n_mutations)
        )
        
        # Compute probabilities for all bootstraps at once
        p1 = np.exp(effects_1)
        p2 = np.exp(effects_2)
        
        # Normalize: divide each row by its sum
        p1 = p1 / p1.sum(axis=1, keepdims=True)
        p2 = p2 / p2.sum(axis=1, keepdims=True)
        
        # Compute mixture distribution
        m = 0.5 * (p1 + p2)
        
        # Compute KL divergences (vectorized)
        # KL(p||m) = sum(p * log(p/m))
        kl_p_m = np.sum(p1 * np.log(p1 / m), axis=1)
        kl_q_m = np.sum(p2 * np.log(p2 / m), axis=1)
        
        # JSD = 0.5 * (KL(p||m) + KL(q||m))
        jsd = 0.5 * (kl_p_m + kl_q_m)
        
        return jsd

    results = []

    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=[
            f'{ha_x}_effect', f'{ha_y}_effect',
            f'{ha_x}_effect_std', f'{ha_y}_effect_std'
        ])

        if len(valid) < min_mutations:
            continue

        # Get observed effects
        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values

        # Get standard deviations
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values

        # Compute observed JSD between ha_x and ha_y
        p_obs = np.exp(effects_x)
        q_obs = np.exp(effects_y)
        p_obs /= p_obs.sum()
        q_obs /= q_obs.sum()
        m_obs = 0.5 * (p_obs + q_obs)
        jsd_obs = 0.5 * (kl_divergence(p_obs, m_obs) + kl_divergence(q_obs, m_obs))

        # Only compute null distribution if JSD exceeds threshold
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue

        # Vectorized bootstrap null distributions
        jsd_null_x = compute_jsd_vectorized(effects_x, std_x, n_bootstrap)
        jsd_null_y = compute_jsd_vectorized(effects_y, std_y, n_bootstrap)
        
        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2

        # Compute empirical p-value (one-tailed test: is observed JSD greater than null?)
        # NOTE(review): this can be exactly 0 when no null sample exceeds the
        # observed JSD; a (1 + k) / (1 + n) correction would avoid zero p-values.
        p_value = np.mean(jsd_null >= jsd_obs)

        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null  # Store for visualization
        })

    # Convert to DataFrame and sort by struct_site using natural sorting
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
    
    return results_df
In [10]:
# Compute JSD with null distributions for each comparison
# Compute JSD with null distributions for each comparison
jsd_with_pvals_h3_h5 = compute_jsd_with_null(
    js_df_h3_h5,
    'h3', 'h5',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

jsd_with_pvals_h3_h7 = compute_jsd_with_null(
    js_df_h3_h7,
    'h3', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

jsd_with_pvals_h5_h7 = compute_jsd_with_null(
    js_df_h5_h7,
    'h5', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

# Apply multiple testing correction (Benjamini-Hochberg FDR)
# Only apply FDR to sites that were tested (non-NaN p-values)
# NOTE(review): mid-notebook import — conventionally this belongs in the
# imports cell at the top so a fresh-kernel reader sees all dependencies.
from scipy.stats import false_discovery_control

def apply_fdr_with_threshold(df):
    """Apply FDR correction only to non-NaN p-values."""
    # Initialize q_value column with NaN
    df['q_value'] = np.nan
    
    # Get indices of non-NaN p-values
    tested_mask = df['p_value'].notna()
    
    if tested_mask.sum() > 0:
        # Apply FDR correction only to tested sites
        df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
    
    return df

# Attach BH-corrected q-values to each comparison's JSD table.
jsd_with_pvals_h3_h5 = apply_fdr_with_threshold(jsd_with_pvals_h3_h5)
jsd_with_pvals_h3_h7 = apply_fdr_with_threshold(jsd_with_pvals_h3_h7)
jsd_with_pvals_h5_h7 = apply_fdr_with_threshold(jsd_with_pvals_h5_h7)

# Report significant sites as fractions (out of ALL sites with JSD measurements)
# NaN q-values (untested sites) compare False against the threshold, so they
# count in the denominator but never in the numerator.
print("Significant sites (H3 vs H5, q < 0.1):")
total_h3h5 = len(jsd_with_pvals_h3_h5)
sig_h3h5 = (jsd_with_pvals_h3_h5['q_value'] < 0.1).sum()
print(f"  {sig_h3h5} / {total_h3h5} sites ({sig_h3h5/total_h3h5:.2%})")

print("\nSignificant sites (H3 vs H7, q < 0.1):")
total_h3h7 = len(jsd_with_pvals_h3_h7)
sig_h3h7 = (jsd_with_pvals_h3_h7['q_value'] < 0.1).sum()
print(f"  {sig_h3h7} / {total_h3h7} sites ({sig_h3h7/total_h3h7:.2%})")

print("\nSignificant sites (H5 vs H7, q < 0.1):")
total_h5h7 = len(jsd_with_pvals_h5_h7)
sig_h5h7 = (jsd_with_pvals_h5_h7['q_value'] < 0.1).sum()
print(f"  {sig_h5h7} / {total_h5h7} sites ({sig_h5h7/total_h5h7:.2%})")
Significant sites (H3 vs H5, q < 0.1):
  294 / 467 sites (62.96%)

Significant sites (H3 vs H7, q < 0.1):
  207 / 467 sites (44.33%)

Significant sites (H5 vs H7, q < 0.1):
  189 / 431 sites (43.85%)
In [11]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1): 
    """
    Plot JSD values with significance coloring.

    Builds a three-panel interactive figure: (1) per-site JSD trace with
    points colored by FDR significance, (2) a mutation-effect scatter that
    filters to the hovered site, and (3) a density of site-level JSD values.

    Parameters
    ----------
    df : pd.DataFrame
        Main dataframe with mutation effects
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD p-values and q-values from compute_jsd_with_null
    ha_x, ha_y : str
        HA subtype names
    identity_df : pd.DataFrame, optional
        DataFrame with sequence identity information
    alpha : float
        Significance threshold for q-value (default 0.1)
    """
    # Look up pairwise percent amino-acid identity for the title, if provided
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None

    # Physicochemical classes used to color mutant amino acids in the scatter
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # NOTE(review): this mutates the caller's dataframe in place (astype on a
    # column of the passed-in df) — confirm downstream cells expect string sites.
    df['struct_site'] = df['struct_site'].astype(str)

    df = df.assign(
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification)
    )

    # Merge significance data with site-level JSD data
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
        f'JS_{ha_x}_vs_{ha_y}', f'rmsd_{ha_x}{ha_y}'
    ]].dropna().drop_duplicates()
    
    # Merge q-values
    site_jsd_df = site_jsd_df.merge(
        jsd_pvals_df[['struct_site', 'q_value']], 
        on='struct_site', 
        how='left'
    )
    
    # Add significance flag
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )

    # Hover selection keyed on site; nearest=True snaps to the closest point
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
    )

    # Natural sort so e.g. site '2' precedes '10' on the ordinal axis
    sorted_sites = natsorted(df['struct_site'].unique())
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites, 
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'JS_{ha_x}_vs_{ha_y}:Q', 
            title=['Jensen-Shannon', 'Divergence'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site', 
            f'{ha_x}_wt_aa', 
            f'{ha_y}_wt_aa', 
            alt.Tooltip(f'JS_{ha_x}_vs_{ha_y}', format='.4f'),
            alt.Tooltip(f'rmsd_{ha_x}{ha_y}', format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )

    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)
    
    # Points layer with conditional formatting based on hover and click
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)  # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )

    # Correlation between cell entry effects plot
    # Filter based on hover (only if nothing clicked) or click
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect", 
                title=["Effect on cell entry", f"in {ha_x.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect", 
                title=["Effect on cell entry", f"in {ha_y.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                        range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
                     f'{ha_x}_effect', f'{ha_x}_effect_std', 
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                    f'JS_{ha_x}_vs_{ha_y}'],  
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )

    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    
    corr_chart = vline + hline + base_corr_chart

    # density plot of site-level JSD values
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=f'JS_{ha_x}_vs_{ha_y}',
        bandwidth=0.02,
        extent=[0,1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title='Jensen-Shannon Divergence'),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )

    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% AAI)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'

    # combine the bar and heatmaps
    combined_chart = alt.vconcat(
        (line + points), corr_chart, density
    ).resolve_scale(
        y='independent', 
        x='independent', 
        color='independent'
    )
    combined_chart = combined_chart.properties(
        title=alt.Title(title_text, 
        offset=0,
        fontSize=18,
        #subtitle=['Hover over sites to see mutation effects. Click to lock selection (double-click to clear).'],
        subtitleFontSize=16,
        anchor='middle'
        )
    )

    return combined_chart

# Interactive JSD figure for the H3 vs. H5 comparison
chart = plot_jsd(
    js_df_h3_h5,
    jsd_with_pvals_h3_h5,
    'h3', 'h5',
    seq_identity
)
chart.display()
In [12]:
# Interactive JSD figure for the H3 vs. H7 comparison
chart = plot_jsd(
    js_df_h3_h7,
    jsd_with_pvals_h3_h7,
    'h3', 'h7',
    seq_identity
)
chart.display()
In [13]:
# Interactive JSD figure for the H5 vs. H7 comparison
chart = plot_jsd(
    js_df_h5_h7,
    jsd_with_pvals_h5_h7,
    'h5', 'h7',
    seq_identity
)
chart.display()
In [14]:
# Write site-level divergence summaries, one CSV per HA pair. The column
# lists differ only in the pair-specific structure/wt/JSD/RMSD columns;
# all three structures' RSA columns are kept in every file. (Replaces three
# copy-pasted export statements with a data-driven loop producing
# byte-identical output files.)
_divergence_outputs = [
    (js_df_h3_h5, '../results/divergence/h3_h5_divergence.csv', [
        'struct_site', '4o5n_aa', '4kwm_aa', 'h3_wt_aa', 'h5_wt_aa',
        '4o5n_aa_chain', '4kwm_aa_chain',
        '4o5n_aa_pdb_site', '4kwm_aa_pdb_site', 'JS_h3_vs_h5', 'rmsd_h3h5',
        '4o5n_aa_RSA', '4kwm_aa_RSA', '6ii9_aa_RSA',
    ]),
    (js_df_h3_h7, '../results/divergence/h3_h7_divergence.csv', [
        'struct_site', '4o5n_aa', '6ii9_aa', 'h3_wt_aa', 'h7_wt_aa',
        '4o5n_aa_chain', '6ii9_aa_chain',
        '4o5n_aa_pdb_site', '6ii9_aa_pdb_site', 'JS_h3_vs_h7', 'rmsd_h3h7',
        '4o5n_aa_RSA', '4kwm_aa_RSA', '6ii9_aa_RSA',
    ]),
    (js_df_h5_h7, '../results/divergence/h5_h7_divergence.csv', [
        'struct_site', '4kwm_aa', '6ii9_aa', 'h5_wt_aa', 'h7_wt_aa',
        '4kwm_aa_chain', '6ii9_aa_chain',
        '4kwm_aa_pdb_site', '6ii9_aa_pdb_site', 'JS_h5_vs_h7', 'rmsd_h5h7',
        '4o5n_aa_RSA', '4kwm_aa_RSA', '6ii9_aa_RSA',
    ]),
]

for _df, _path, _cols in _divergence_outputs:
    _df[_cols].drop_duplicates().reset_index(drop=True).to_csv(_path, index=False)

H7 a2,6 vs. H7 a2,3¶

In [15]:
def read_and_filter_data(
    path, 
    effect_std_filter=2,
    times_seen_filter=2,
    n_selections_filter=2,
    clip_effect=-5 
):
    """Read a cell-entry effects CSV, filter low-confidence mutations, and
    append zero-effect wildtype rows.

    Parameters
    ----------
    path : str
        CSV with columns site, wildtype, mutant, effect, effect_std,
        times_seen, n_selections.
    effect_std_filter : float
        Keep mutations with effect_std <= this value.
    times_seen_filter, n_selections_filter : int
        Keep mutations with times_seen / n_selections >= these values.
    clip_effect : float
        Lower bound applied to 'effect'. Bug fix: this parameter was
        previously printed but ignored — clipping was hard-coded at -5
        regardless of the argument. Default behavior is unchanged.

    Returns
    -------
    pd.DataFrame
        Filtered effects (stop codons '*' and indels '-' removed), sites as
        strings, sorted by site then mutant, with one zero-effect wildtype
        row per retained site.
    """
    print(f'Reading data from {path}')
    print(
        f"Filtering for:\n"
        f"  effect_std <= {effect_std_filter}\n"
        f"  times_seen >= {times_seen_filter}\n"
        f"  n_selections >= {n_selections_filter}"
    )
    print(f"Clipping effect values at {clip_effect}")

    df = pd.read_csv(path).query(
        'effect_std <= @effect_std_filter and \
        times_seen >= @times_seen_filter and \
        n_selections >= @n_selections_filter'
    ).query(
        'mutant not in ["*", "-"]' # don't want stop codons/indels
    )

    df['site'] = df['site'].astype(str)
    # Use the clip_effect parameter (was hard-coded to .clip(-5), silently
    # ignoring the argument the caller passed and that was printed above).
    df['effect'] = df['effect'].clip(lower=clip_effect)

    df = pd.concat([
        df,
        df[['site', 'wildtype']].drop_duplicates().assign(
            mutant=lambda x: x['wildtype'],
            effect=0.0,
            effect_std=0.0,
            times_seen=np.nan,
            n_selections=np.nan
        ) # add wildtype sites with zero effect
    ], ignore_index=True).sort_values(['site', 'mutant']).reset_index(drop=True)
    
    return df
In [16]:
# Load the two H7 receptor-condition datasets (a2,3- vs a2,6-linked sialic
# acid entry) and rename columns to the '{ha}_...' scheme the JSD helpers expect.
h7_23_df = read_and_filter_data('../data/cell_entry_effects/293_2-3_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-3_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-3_effect',
        'effect_std': 'h7_2-3_effect_std'
    }
)
h7_26_df = read_and_filter_data('../data/cell_entry_effects/293_2-6_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-6_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-6_effect',
        'effect_std': 'h7_2-6_effect_std'
    }
)

# Inner merge keeps only mutations measured in both receptor conditions.
# RMSD is 0 because both conditions use the same H7 structure.
h7_23_26_df = pd.merge(
    h7_23_df,
    h7_26_df,
    left_on=['struct_site', 'h7_2-3_wt_aa', 'mutant'],
    right_on=['struct_site', 'h7_2-6_wt_aa', 'mutant'],
).assign(
    **{'rmsd_h7_2-3h7_2-6': 0}
)

h7_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_2-3_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Reading data from ../data/cell_entry_effects/293_2-6_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Out[16]:
struct_site h7_2-3_wt_aa mutant h7_2-3_effect h7_2-3_effect_std h7_2-6_wt_aa h7_2-6_effect h7_2-6_effect_std rmsd_h7_2-3h7_2-6
0 100 G A -0.00205 0.9766 G -1.319 0.8159 0
1 100 G C -3.91300 0.0130 G -4.422 0.0000 0
2 100 G D -4.76800 0.0000 G -4.937 0.0000 0
3 100 G G 0.00000 0.0000 G 0.000 0.0000 0
4 100 G H -4.64600 0.0000 G -4.800 0.0000 0
In [17]:
# JSD between the two H7 receptor conditions — a same-background comparison
# that serves as a measurement-noise control for the cross-subtype JSDs.
js_df_h7_23_26 = compute_js_divergence_per_site(h7_23_26_df, 'h7_2-3', 'h7_2-6', min_mutations=10)
In [18]:
# Compute JSD with null distributions for each comparison
jsd_with_pvals_h7_23_26 = compute_jsd_with_null(
    js_df_h7_23_26,
    'h7_2-3', 'h7_2-6',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

jsd_with_pvals_h7_23_26 = apply_fdr_with_threshold(jsd_with_pvals_h7_23_26)

# Report significant sites as fractions (out of ALL sites with JSD measurements)
print("Significant sites (H7 2-3 vs H7 2-6, q < 0.1):")
total_h7_23_26 = len(jsd_with_pvals_h7_23_26)
sig_h7_23_26 = (jsd_with_pvals_h7_23_26['q_value'] < 0.1).sum()
print(f"  {sig_h7_23_26} / {total_h7_23_26} sites ({sig_h7_23_26/total_h7_23_26:.2%})")
Significant sites (H7 2-3 vs H7 2-6, q < 0.1):
  0 / 492 sites (0.00%)
In [19]:
chart = plot_jsd(
    js_df_h7_23_26,
    jsd_with_pvals_h7_23_26,
    'h7_2-3', 'h7_2-6'
)
chart.display()
In [20]:
def plot_ridgeline_density(dfs_dict, x_col_template='JS_{ha_x}_vs_{ha_y}', 
                           bandwidth=0.02, extent=[0,0.65], 
                           colors=None, width=200, height=400,
                           overlap=2.5, label_mapping=None):
    """
    Plot ridgeline (joyplot) density plots for multiple dataframes.
    
    Parameters:
    -----------
    dfs_dict : dict
        Dictionary where keys are comparison labels (e.g., 'h3-h5', 'h3-h7') 
        and values are tuples of (df, ha_x, ha_y)
        Example: {'h3-h5': (js_df_h3_h5, 'h3', 'h5'), 
                  'h3-h7': (js_df_h3_h7, 'h3', 'h7')}
    x_col_template : str
        Template for column name with {ha_x} and {ha_y} placeholders
    bandwidth : float
        Bandwidth for density estimation
    extent : list
        [min, max] for density calculation
    colors : list or None
        List of colors for each comparison. If None, uses default color scheme
    width, height : int
        Dimensions of the plot
    overlap : float
        How much the ridges overlap (higher = more overlap)
    label_mapping : dict or None
        Optional mapping from dfs_dict keys to display labels for the row
        headers (a list value renders as a multi-line label).
    
    Returns:
    --------
    alt.Chart : Ridgeline density plot
    """
    # Default color scheme if none provided
    if colors is None:
        colors = ['#8DD3C7', '#FFFFB3', '#BEBADA', '#FB8072', '#80B1D3', '#FDB462']
    
    # Stack every comparison's JSD values into one long dataframe
    combined_data = []
    for comparison, (df, ha_x, ha_y) in dfs_dict.items():
        col_name = x_col_template.format(ha_x=ha_x, ha_y=ha_y)
        temp_df = df[[col_name]].copy()
        temp_df['comparison'] = comparison
        temp_df['value'] = temp_df[col_name]
        combined_data.append(temp_df[['value', 'comparison']])
    
    combined_df = pd.concat(combined_data, ignore_index=True)
    
    # Swap raw keys for display labels if a mapping was provided
    if label_mapping is not None:
        combined_df['comparison'] = combined_df['comparison'].map(label_mapping)
    
    # Calculate step size for ridgeline spacing
    step = height / (len(dfs_dict) * overlap)
    
    # Create the ridgeline plot: one faceted row per comparison, with negative
    # facet spacing so consecutive ridges overlap
    ridgeline = alt.Chart(combined_df).transform_density(
        density='value',
        bandwidth=bandwidth,
        extent=extent,
        groupby=['comparison'],
        steps=200
    ).mark_area(
        opacity=1,
        stroke='black',
        strokeWidth=1,
        interpolate='monotone'
    ).encode(
        alt.X('value:Q', title='Jensen-Shannon Divergence'),
        alt.Y('density:Q', 
              title='Density',
              axis=None),
        alt.Row('comparison:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left')),
        alt.Fill('comparison:N',
                 legend=None,
                 scale=alt.Scale(range=colors[:len(dfs_dict)]))
    ).properties(
        width=width,
        height=step,
        bounds='flush'
    ).configure_facet(
        spacing=-(step * (overlap - 1))
    ).configure_view(
        stroke=None
    ).configure_header(
        labelFontSize=14
    )
    
    return ridgeline

# Example usage:
dfs_to_plot = {
    'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
    'h3-h7': (js_df_h3_h7, 'h3', 'h7'),
    'h5-h7': (js_df_h5_h7, 'h5', 'h7'),
    'h7_2-3-h7_2-6': (js_df_h7_23_26, 'h7_2-3', 'h7_2-6')
}

plot_ridgeline_density(
    dfs_to_plot, 
    label_mapping={
        'h3-h5': 'H3 vs. H5',
        'h3-h7': 'H3 vs. H7',
        'h5-h7': 'H5 vs. H7',
        'h7_2-3-h7_2-6': ['H7 (a2,3) vs.', 'H7 (a2,6)']
    }
).display()

Examples of mutation effect correlations¶

In [21]:
def plot_correlation(df, ha_x, ha_y, site, decimal_places=2):
    """Plot per-mutant cell-entry effects in one HA against another at one site.

    Each mutant amino acid is drawn as a text mark positioned by its effect in
    ``ha_x`` (x-axis) vs. ``ha_y`` (y-axis), colored by amino-acid class, with
    dashed reference lines at zero effect on both axes.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-mutant table with columns ``mutant``, ``struct_site``,
        ``{ha}_effect`` / ``{ha}_effect_std`` / ``{ha}_wt_aa`` for each HA,
        and a ``JS_{ha_x}_vs_{ha_y}`` divergence column. Not modified.
    ha_x, ha_y : str
        Column prefixes (e.g. ``'h3'``, ``'h5'``) for the x- and y-axis HAs.
    site : str or int
        Structural site to plot; compared against ``struct_site`` as a string.
    decimal_places : int, optional
        Decimal places shown for the divergence value in the chart title.

    Returns
    -------
    alt.LayerChart
        Layered chart (zero-lines + text scatter) titled with the site and its
        Jensen-Shannon divergence.

    Raises
    ------
    ValueError
        If no rows in ``df`` match the requested site.
    """
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }

    # Build the per-site frame without mutating the caller's DataFrame
    # (the previous version assigned to df['struct_site'] in place, and
    # re-applied the same site query a second time when building the chart).
    site_df = (
        df.assign(
            struct_site=lambda x: x['struct_site'].astype(str),
            mutant_type=lambda x: x['mutant'].map(amino_acid_classification),
        )
        .query(f'struct_site == "{site}"')
    )
    if site_df.empty:
        raise ValueError(f'no rows with struct_site == {site!r} in df')

    # Divergence appears constant within a site (single unique value taken).
    jsd = site_df[f'JS_{ha_x}_vs_{ha_y}'].unique()[0]

    base_corr_chart = (alt.Chart(site_df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6, 1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6, 1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic', 'Negative', 'Positive', 'Special'],
                        range=["#4e79a7", "#f28e2c", "#e15759", "#76b7b2", "#59a14f", "#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                    f'JS_{ha_x}_vs_{ha_y}'],
        ).properties(
            height=125,
            width=125,
        )
    )

    # Dashed gray reference lines marking zero effect on each axis.
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(
        color='gray', opacity=0.5, strokeDash=[2, 4]).encode(x='x:Q')
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
        color='gray', opacity=0.5, strokeDash=[2, 4]).encode(y='y:Q')

    corr_chart = (vline + hline + base_corr_chart).properties(
        title=alt.Title(
            [f'Site {site}', f'Divergence = {jsd:.{decimal_places}f}'],
            offset=0,
            fontSize=16,
            anchor='middle'
        )
    )
    return corr_chart
In [22]:
# H3 vs. H5 mutation-effect correlations at four structural sites
# (site 97 shown with extra decimal places in its divergence title).
site_panels = [
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86'),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='97', decimal_places=4),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='198'),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='241'),
]
row_chart = site_panels[0]
for panel in site_panels[1:]:
    row_chart = row_chart | panel
row_chart.display()
In [23]:
# Site 86: correlation of mutation effects for each pairwise HA comparison.
site_86_row = (
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86')
    | plot_correlation(js_df_h3_h7, 'h3', 'h7', site='86')
    | plot_correlation(js_df_h5_h7, 'h5', 'h7', site='86')
)
site_86_row.display()
In [24]:
# Site 173: correlation of mutation effects for each pairwise HA comparison.
site_173_row = (
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='173')
    | plot_correlation(js_df_h3_h7, 'h3', 'h7', site='173')
    | plot_correlation(js_df_h5_h7, 'h5', 'h7', site='173')
)
site_173_row.display()
In [25]:
def _display_site_comparisons(site):
    """Display H3/H5, H3/H7, and H5/H7 mutation-effect correlation panels
    side by side for a single structural site.

    Parameters
    ----------
    site : str
        Structural site passed through to ``plot_correlation``.
    """
    (
        plot_correlation(js_df_h3_h5, 'h3', 'h5', site=site) |
        plot_correlation(js_df_h3_h7, 'h3', 'h7', site=site) |
        plot_correlation(js_df_h5_h7, 'h5', 'h7', site=site)
    ).display()

# One row of panels per site; replaces three copy-pasted statements that
# differed only in the site argument.
for _site in ['178', '123', '176']:
    _display_site_comparisons(_site)

# H3 forms H bonds at 178, 123, 176, and 211. 
# H5 and H7 do not form any H bonds in this region, and therefore tolerate many more amino acids.